import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class FedQlearning_genb:
    def __init__(self, mdp, c1, c2, total_episodes, num_agents):
        """Allocate the global tables and the per-agent accumulators.

        Args:
            mdp: Environment object. Only its ``H``, ``S`` and ``A``
                attributes are read here, to size the tables.
            c1: Multiplier of the ``ucb_bonus`` term in ``aggregate_data``.
            c2: Multiplier of the ``betanew`` term in ``aggregate_data``.
            total_episodes: Number of outer-loop iterations run by ``learn``.
            num_agents: Number of agents; leading dimension of every
                per-agent array.
        """
        self.mdp = mdp
        self.c1 = c1
        self.c2 = c2
        self.total_episodes = total_episodes
        self.num_agents = num_agents

        horizon, n_states, n_actions = mdp.H, mdp.S, mdp.A
        table_shape = (horizon, n_states, n_actions)
        per_agent_shape = (num_agents,) + table_shape

        # Bookkeeping filled in by ``learn``.
        self.trigger_times = 0
        self.comm_episode_collection = []
        self.regret = []
        self.cost = []

        # State values; row ``H`` exists so that ``V_func[step + 1, ...]`` is
        # a valid lookup at the last step, and it starts at zero.
        self.V_func = np.zeros((horizon + 1, n_states), dtype=np.float32)

        # Global statistics, indexed by (step, state, action).
        self.global_N = np.zeros(table_shape, dtype=np.int32)
        self.V_sum_all = np.zeros(table_shape, dtype=np.float32)
        self.V2_sum_all = np.zeros(table_shape, dtype=np.float32)
        self.beta = np.zeros(table_shape, dtype=np.float32)

        # Q starts at ``H - h`` for every state-action pair of step ``h``.
        steps_to_go = np.arange(horizon, 0, -1, dtype=np.float32)
        self.global_Q = np.empty(table_shape, dtype=np.float32)
        self.global_Q[...] = steps_to_go[:, None, None]

        # Per-agent statistics collected between two aggregations.
        self.agent_N = np.zeros(per_agent_shape, dtype=np.int32)
        self.agent_V_sum = np.zeros(per_agent_shape, dtype=np.float32)
        self.agent_V2_sum = np.zeros(per_agent_shape, dtype=np.float32)

    def run_episode(self, agent_id):
        """Roll out one episode for ``agent_id`` under the current greedy policy.

        For every visited (step, state, action) the agent's visit counter is
        incremented and ``V_func[step + 1, next_state]`` and its square are
        added to the agent's running sums.

        Args:
            agent_id: Index into the first axis of the per-agent arrays.

        Returns:
            A tuple ``(rewards, event_triggered, state_init)``. ``rewards`` is
            an (H, S, A) array holding the reward returned by ``mdp.step`` at
            each visited entry, with unvisited entries left at zero.
            ``event_triggered`` is True if ``check_event_triggered`` held at
            any step. ``state_init`` is the state returned by ``mdp.reset``.
        """
        policy = self.choose_action()
        start_state = self.mdp.reset()
        horizon = self.mdp.H
        observed = np.zeros((horizon, self.mdp.S, self.mdp.A))
        triggered = False

        current = start_state
        for h in range(horizon):
            # The policy is one-hot per (step, state); argmax recovers the action.
            chosen = np.argmax(policy[h, current])
            successor, payoff = self.mdp.step(chosen)

            visited = (agent_id, h, current, chosen)
            next_value = self.V_func[h + 1, successor]
            self.agent_N[visited] += 1
            self.agent_V_sum[visited] += next_value
            self.agent_V2_sum[visited] += next_value ** 2

            observed[h, current, chosen] = payoff

            # The condition is evaluated at every step, and the episode is
            # always played to the end even once it has fired.
            if self.check_event_triggered(agent_id, h, current, chosen):
                triggered = True

            current = successor

        return observed, triggered, start_state

    def choose_action(self):
        """Return the greedy policy for ``global_Q`` as a one-hot table.

        Returns:
            A float array with one row per (step, state) pair and ``mdp.A``
            columns. Entry ``[h, s, a]`` is 1 when ``a`` is
            ``np.argmax(global_Q[h, s])`` (the lowest index on ties) and 0
            otherwise.
        """
        greedy = np.argmax(self.global_Q, axis=2)
        # Row ``a`` of the identity matrix is the one-hot encoding of action ``a``.
        return np.eye(self.mdp.A)[greedy]


    def check_event_triggered(self, agent_id, step, state, action):
        """Tell whether an agent's local visit count has reached its quota.

        The quota for a (step, state, action) triple is
        ``max(1, floor(global_N / (num_agents * H * (H + 1))))``.

        Args:
            agent_id: Index of the agent whose local counter is tested.
            step: Step index within the episode.
            state: State index.
            action: Action index.

        Returns:
            A truthy value when ``agent_N[agent_id, step, state, action]`` is
            at least the quota, a falsy one otherwise.
        """
        horizon = self.mdp.H
        per_agent_share = (1.0 / (horizon * (horizon + 1))) / self.num_agents
        quota = int(np.floor(per_agent_share * self.global_N[step, state, action]))
        # At least one local visit is always required before triggering.
        if quota < 1:
            quota = 1
        return self.agent_N[agent_id, step, state, action] >= quota

    def aggregate_data(self, policy_k, rewards):
        """Fold the agents' local statistics into the global tables, then clear them.

        Only triples (h, s, a) where ``a`` is the argmax of ``policy_k[h, s]``
        and at least one agent has a non-zero local count are updated; all
        other entries of the global tables are left as they are.

        For each updated triple the method advances ``global_N``,
        ``V_sum_all`` and ``V2_sum_all`` by the summed per-agent statistics,
        recomputes ``beta`` and writes a new ``global_Q`` value. Two update
        rules are used, chosen by comparing the global count *before* the
        update with ``i_0 = 2 * M * H * (H + 1)``:

        * below ``i_0``: one Q update per agent that visited the triple, each
          with step size ``(H + 1) / (H + t)`` for the running count ``t``;
        * otherwise: a single Q update whose weight is the product of
          ``(1 - step size)`` over all new visits and whose target uses the
          mean of the agents' accumulated next-state values.

        Args:
            policy_k: Table indexed as ``policy_k[h, s]``; the argmax of each
                row is taken as the action played at (h, s).
            rewards: Table indexed as ``rewards[h, s, a]`` giving the reward
                used in the update target.

        Returns:
            None. ``agent_N``, ``agent_V_sum`` and ``agent_V2_sum`` are zeroed
            before returning.
        """
        H, M = self.mdp.H, self.num_agents
        # Global-count threshold separating the two update rules below.
        i_0 = 2 * M * H * (H + 1)
        for h in range(H):
            for s in range(self.mdp.S):
                for a in range(self.mdp.A):
                    if a != np.argmax(policy_k[h, s]) or self.agent_N[:, h, s, a].sum() == 0:
                        # No update required, retain previous Q-values
                        continue
                    else:
                        # N_h_k: global count before this aggregation;
                        # n_h_k: new visits summed over all agents.
                        N_h_k = self.global_N[h, s, a]
                        n_h_k = self.agent_N[:, h, s, a].sum()

                        if N_h_k < i_0:
                            # Case 1: Update rule for small N_h_k (update Q sequentially)
                            t00 = N_h_k
                            # Running product of (1 - step_size) over the agents processed.
                            alpha_agg_side = 1
                            for ag_id in range(self.num_agents):
                                if self.agent_N[ag_id, h, s, a] > 0:
                                    t00 = t00 + 1
                                    step_size = (H + 1) / (H + t00)
                                    alpha_agg_side = alpha_agg_side*(1 - step_size)
                                    # NOTE(review): the target uses the agent's *summed*
                                    # next-state value, not a mean. That equals a single
                                    # sample only if each agent has at most one visit
                                    # here - presumably ensured by the trigger threshold
                                    # in this regime; confirm.
                                    self.global_Q[h, s, a] = (1 - step_size) * self.global_Q[h, s, a] + step_size * (rewards[h, s, a] + self.agent_V_sum[ag_id, h, s, a])
                            # Merge the local statistics into the global totals.
                            self.global_N[h,s,a] += sum(self.agent_N[:,h,s,a])
                            self.V_sum_all[h,s,a] += sum(self.agent_V_sum[:, h, s, a])
                            self.V2_sum_all[h,s,a] += sum(self.agent_V2_sum[:, h, s, a])
                            # Empirical variance of the accumulated next-state values
                            # (mean of squares minus squared mean), clipped at zero.
                            sigma2_v = max(self.V2_sum_all[h,s,a]/self.global_N[h, s, a] - (self.V_sum_all[h,s,a]/self.global_N[h, s, a])**2,0)
                            # Count-based bonus; it is zero at the last step (h == H - 1).
                            ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / self.global_N[h, s, a])
                            # Variance-based bonus, capped at twice the count-based one.
                            betanew = min(self.c2*(np.sqrt(H*(sigma2_v+H)/self.global_N[h, s, a])+np.sqrt(H**7*self.mdp.S*self.mdp.A)/self.global_N[h, s, a]), 2*ucb_bonus)
                            # Only the change relative to the decayed previous bonus is added.
                            bonus = betanew - alpha_agg_side * self.beta[h,s,a]
                            self.beta[h,s,a] = betanew
                            self.global_Q[h, s, a] += bonus/2
                            # Alternative second term (disabled):
                            # +np.sqrt(M*H**6*self.mdp.S*self.mdp.A)/self.global_N[h, s, a])
                        else:
                            # Case 2: large N_h_k - one update for all n_h_k new visits.
                            t00 = N_h_k
                            # Product of (1 - step_size) over the n_h_k new visits.
                            alpha_agg_side = 1.0
                            for i in range(n_h_k):
                                t00 = t00 + 1
                                step_size = (H + 1) / (H + t00)
                                alpha_agg_side = alpha_agg_side*(1 - step_size)
                            # Merge the local statistics into the global totals.
                            self.global_N[h,s,a] += sum(self.agent_N[:,h,s,a])
                            self.V_sum_all[h,s,a] += sum(self.agent_V_sum[:, h, s, a])
                            self.V2_sum_all[h,s,a] += sum(self.agent_V2_sum[:, h, s, a])
                            # Same variance estimate and bonus as in case 1.
                            sigma2_v = max(self.V2_sum_all[h,s,a]/self.global_N[h, s, a] - (self.V_sum_all[h,s,a]/self.global_N[h, s, a])**2,0)
                            ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / self.global_N[h, s, a])
                            betanew = min(self.c2*(np.sqrt(H*(sigma2_v+H)/self.global_N[h, s, a])+np.sqrt(H**7*self.mdp.S*self.mdp.A)/self.global_N[h, s, a]), 2*ucb_bonus)
                            bonus = betanew - alpha_agg_side*self.beta[h,s,a]
                            self.beta[h,s,a] = betanew
                            # Target: reward plus the mean accumulated next-state value
                            # over the n_h_k new visits.
                            self.global_Q[h, s, a] = alpha_agg_side * self.global_Q[h, s, a] + (1 - alpha_agg_side) * (rewards[h, s, a] + sum(self.agent_V_sum[:, h, s, a])/n_h_k) +bonus/2

        # Reset every agent's local statistics
        self.agent_N.fill(0)
        self.agent_V_sum.fill(0)
        self.agent_V2_sum.fill(0)


    def learn(self):
        """Run the training loop, recording regret and communication cost.

        In every episode each agent plays one rollout under the current
        policy. If any agent reports the trigger condition, the local
        statistics are aggregated once all agents have finished, after which
        the policy and ``V_func`` are refreshed from ``global_Q``.

        Side effects: appends one entry per agent rollout to ``regret`` (the
        running sum of ``best_value - value`` at the rollout's initial state),
        one entry per episode to ``cost`` (aggregations so far), and the
        episode index of every aggregation to ``comm_episode_collection``.

        Returns:
            ``(best_Q, global_Q)``, where ``best_Q`` is the third item
            returned by ``mdp.best_gen()`` and ``global_Q`` is this object's
            own Q-table (not a copy).
        """
        mdp = self.mdp
        horizon, n_states = mdp.H, mdp.S

        # Regret accumulated over all agent rollouts.
        self.regret_cum = 0
        optimal_value, _, optimal_Q = mdp.best_gen()

        # Rewards seen so far for the policy's actions. An entry is only
        # written while it is still zero, i.e. rewards are treated as
        # deterministic.
        known_rewards = np.zeros((horizon, n_states, mdp.A))

        self.V_func[:horizon] = self.global_Q.max(axis=2)
        policy = self.choose_action()

        # Largest value allowed at step h once the table is refreshed: H - h.
        step_caps = horizon - np.arange(horizon)

        for episode in range(self.total_episodes):
            policy_value = mdp.value_gen(policy)
            sync_needed = False

            for agent_id in range(self.num_agents):
                episode_rewards, agent_flag, start = self.run_episode(agent_id)
                self.regret_cum = self.regret_cum + optimal_value[start] - policy_value[start]
                self.regret.append(self.regret_cum)

                for h, s in np.ndindex(horizon, n_states):
                    a = np.argmax(policy[h, s])
                    if known_rewards[h, s, a] == 0:
                        known_rewards[h, s, a] = episode_rewards[h, s, a]

                if agent_flag:
                    sync_needed = True

            if sync_needed:
                self.trigger_times += 1
                self.comm_episode_collection.append(episode)
                self.aggregate_data(policy, known_rewards)
                policy = self.choose_action()
                self.V_func[:horizon] = np.minimum(
                    step_caps[:, None], self.global_Q.max(axis=2)
                )

            self.cost.append(self.trigger_times)

        return optimal_Q, self.global_Q